import torch
from transformers import StoppingCriteria, StoppingCriteriaList
from llava_v15.conversation import conv_llava_v1, SeparatorStyle
import matplotlib.pyplot as plt

class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as any of ``keywords`` shows up in the output.

    Every call inspects batch element 0 only and applies two checks:

    * a fast path comparing the newest token id with the ids of keywords that
      tokenize to exactly one id, and
    * a slow path that decodes everything generated after the prompt and
      searches for each keyword as a substring.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """
        Args:
            keywords: Strings whose appearance should end generation.
            tokenizer: Tokenizer used to encode ``keywords`` and to decode the
                generated ids.
            input_ids: Prompt ids; only ``input_ids.shape[1]`` is read (on the
                first call) to locate where generated tokens start.
        """
        self.keywords = keywords
        # Only keywords that encode to a single id can be matched by the fast
        # single-token comparison; all others rely on the decoded-text search.
        encoded = [tokenizer(keyword).input_ids for keyword in keywords]
        self.keyword_ids = [ids[0] for ids in encoded if isinstance(ids, list) and len(ids) == 1]
        self.tokenizer = tokenizer
        # Resolved lazily on the first call so construction never touches the
        # shape of input_ids.
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when a stop keyword has been produced for batch element 0."""
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
        # No early return after initialising start_len: the first generated
        # token must be checked like every later one.

        last_token = output_ids[0, -1]
        for keyword_id in self.keyword_ids:
            if last_token == keyword_id:
                return True

        # NOTE(review): this slice assumes output_ids still contains the prompt
        # ids; if the model's generate() returns only new tokens, the first
        # start_len generated tokens are skipped here — confirm.
        generated = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
        return any(keyword in generated for keyword in self.keywords)


class Generator:
    """Wrapper around a LLaVA-style ``model.generate`` call.

    Holds the model, tokenizer and sampling settings, and stops generation on
    the turn separator of the ``conv_llava_v1`` conversation template.
    """

    def __init__(self, model, tokenizer, max_new_tokens=1024, temperature=0.2, device='cuda:0'):
        """
        Args:
            model: Model exposing ``generate(input_ids, images=..., ...)``.
            tokenizer: Tokenizer matching ``model``.
            max_new_tokens: Generation budget used when a call passes ``None``.
            temperature: Sampling temperature (sampling is always enabled).
            device: Device the image tensor is moved to before generation.
        """
        self.model = model
        self.device = device
        self.tokenizer = tokenizer

        self.max_new_tokens = max_new_tokens
        self.temperature = temperature

        # String that terminates an assistant turn in the llava_v1 template.
        self.stop_str = conv_llava_v1.sep if conv_llava_v1.sep_style != SeparatorStyle.TWO else conv_llava_v1.sep2
        self.keywords = [self.stop_str]

    def _run_generate(self, input_ids, image, max_new_tokens, **extra_generate_kwargs):
        """Sample from the model with this generator's settings.

        ``max_new_tokens=None`` falls back to ``self.max_new_tokens``; any
        ``extra_generate_kwargs`` are forwarded to ``model.generate``.
        """
        if max_new_tokens is None:
            max_new_tokens = self.max_new_tokens
        stopping_criteria = KeywordsStoppingCriteria(self.keywords, self.tokenizer, input_ids)
        with torch.inference_mode():
            return self.model.generate(
                input_ids,
                images=image.to(dtype=torch.float16, device=self.device, non_blocking=True),
                do_sample=True,
                temperature=self.temperature,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
                **extra_generate_kwargs,
            )

    def generate(self, prompt, image, max_new_tokens=None):
        """Generate a text response for ``prompt`` conditioned on ``image``.

        Args:
            prompt: 2-D prompt token-id tensor, passed straight to
                ``model.generate`` (the stopping criterion reads ``shape[1]``).
            image: Preprocessed image tensor; cast to float16 and moved to
                ``self.device``.
            max_new_tokens: Generation budget; ``None`` uses ``self.max_new_tokens``.

        Returns:
            Decoded text of batch element 0, whitespace-stripped, with a
            trailing ``self.stop_str`` removed.
        """
        output_ids = self._run_generate(prompt, image, max_new_tokens)

        # NOTE(review): the whole output is decoded, which assumes
        # model.generate returns only newly generated ids — confirm.
        outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        # Guard: with an empty stop_str, endswith("") is True and
        # outputs[:-0] would wipe the whole string.
        if self.stop_str and outputs.endswith(self.stop_str):
            outputs = outputs[:-len(self.stop_str)]
        return outputs.strip()

    def generate_attention(self, prompt, image, max_new_tokens=None, save_path="temp.png"):
        """Generate with attention outputs and save a heat-map of the final step.

        Args:
            prompt: Object with an ``input_ids`` attribute; ``prompt.input_ids[0]``
                is used as the model input.
            image: Preprocessed image tensor; cast to float16 and moved to
                ``self.device``.
            max_new_tokens: Generation budget; ``None`` uses ``self.max_new_tokens``.
            save_path: File the attention heat-map is written to.

        Returns:
            The ``model.generate`` output produced with
            ``return_dict_in_generate=True`` and ``output_attentions=True``.
        """
        input_ids = prompt.input_ids[0]
        outputs = self._run_generate(
            input_ids,
            image,
            max_new_tokens,
            return_dict_in_generate=True,
            output_attentions=True,
        )

        # Presumably attentions is indexed [generation step][layer], each a
        # (batch, heads, query, key) tensor per transformers' generate output —
        # confirm. Take the last layer at the last step and average over dim 1.
        attention_matrix = outputs["attentions"][-1][-1].mean(dim=1)
        # .float(): upcast half-precision weights before handing them to numpy/matplotlib.
        attention_matrix_np = attention_matrix.squeeze(0).float().cpu().numpy()

        fig = plt.figure(figsize=(8, 8))
        try:
            plt.imshow(attention_matrix_np, cmap='viridis', interpolation='nearest', vmax=0.1)
            plt.colorbar(label='Attention Weight')
            plt.title('Attention Matrix')
            plt.savefig(save_path)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

        print(f"Saved as {save_path}")

        return outputs